package com.auditlog.sql.runner;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.auditlog.datasource.StatementProxy;
import com.auditlog.datasource.db.SqlType;
import com.auditlog.datasource.struct.RecordImage;
import com.auditlog.datasource.struct.StatementMetaData;
import com.auditlog.datasource.struct.TableRecord;
import com.auditlog.datasource.struct.TupleTwo;
import com.auditlog.datasource.table.TableMeta;
import com.auditlog.util.IndexUtils;
import net.sf.jsqlparser.statement.Statement;

import java.sql.SQLException;
import java.util.*;
import java.util.stream.Collectors;

public class SqlRunnerChain {

    /**
     * Delimiter used by {@link #getKey(String, SqlType)} to join an upper-cased table name and a SQL type.
     */
    private static final String KEY_DELIMITER = "|`$`|";

    /** The batch's statements, each paired with the runner that handles it. */
    private final List<TupleTwo<Statement, SqlRunner>> tuples;

    /**
     * One record image per (upper-cased table name, SQL type) group, keyed by {@link #getKey(String, SqlType)}.
     */
    private final Map<String, RecordImage> recordImageMap = new LinkedHashMap<>();

    /**
     * A single shared runner used for key assembly, delimiter lookup and the helper queries issued by this
     * class, so that every key built here is assembled with the same delimiter.
     */
    private final AbstractSqlRunner defaultSqlRunner;

    /**
     * @param statementProxy proxy handed to the internal {@code DefaultSqlRunner}
     * @param tuples         the batch's statements paired with the runner responsible for each
     */
    public SqlRunnerChain(StatementProxy statementProxy, List<TupleTwo<Statement, SqlRunner>> tuples) {
        this.tuples = tuples;
        this.defaultSqlRunner = new DefaultSqlRunner(statementProxy, null);
    }

    /**
     * Pre-processing step, run before the batch executes.
     *
     * <p>Validates every runner and collects its statement metadata. The metadata is then grouped by
     * upper-cased table name and SQL type; statements without a table name are ignored. A record image is
     * captured per group:
     * <ul>
     *   <li>INSERT groups record the primary-key values being inserted.</li>
     *   <li>UPDATE and DELETE groups query the currently matching rows as the before-image.</li>
     * </ul>
     * Other SQL types are ignored.
     *
     * @throws SQLException if capturing a record image fails
     */
    public void beforeExecute() throws SQLException {
        Map<String, Map<SqlType, List<StatementMetaData>>> metaMap = new HashMap<>();
        for (TupleTwo<Statement, SqlRunner> tuple : tuples) {
            SqlRunner runner = tuple.getV2();
            runner.validate();
            StatementMetaData metaData = runner.getMetaData(tuple.getV1());
            String tableName = metaData.getTableName();
            if (StrUtil.isNotEmpty(tableName)) {
                metaMap.computeIfAbsent(tableName.toUpperCase(Locale.ROOT), key -> new HashMap<>())
                        .computeIfAbsent(metaData.getSqlType(), key -> new ArrayList<>())
                        .add(metaData);
            }
        }
        for (Map.Entry<String, Map<SqlType, List<StatementMetaData>>> tableEntry : metaMap.entrySet()) {
            String tableName = tableEntry.getKey();
            for (Map.Entry<SqlType, List<StatementMetaData>> typeEntry : tableEntry.getValue().entrySet()) {
                SqlType sqlType = typeEntry.getKey();
                // UPDATE and DELETE could be fetched with a single query, but they are queried separately so the
                // fetched rows can be attributed to either the delete or the update. (A marker column would also
                // work, but would require larger changes to the project; left as a future optimization.)
                if (sqlType == SqlType.INSERT) {
                    handleInsert(tableName, typeEntry.getValue());
                } else if (sqlType == SqlType.UPDATE || sqlType == SqlType.DELETE) {
                    handleDeleteOrUpdate(tableName, sqlType, typeEntry.getValue());
                }
            }
        }
        // NOTE(review): validNotContainSameRecord() used to be called here and is currently disabled;
        // it stays public for callers that want the duplicate-record check.
    }

    /**
     * Post-processing step, run after the batch executes.
     *
     * <p>For each table, it collects every primary key captured by {@link #beforeExecute()}: inserted keys plus
     * keys possibly affected by UPDATE/DELETE. It re-reads those rows with one query per table and copies the
     * matching rows into each record image's after-record.
     *
     * @throws SQLException if re-reading the rows fails
     */
    public void afterExecute() throws SQLException {
        Map<String, TableRecord> tableRecordMap = new HashMap<>();
        for (Map.Entry<String, List<RecordImage>> tableEntry : groupByUpperTableName().entrySet()) {
            List<List<Object>> pks = new ArrayList<>();
            List<String> pkNames = null;
            List<String> selectColumns = null;
            for (RecordImage recordImage : tableEntry.getValue()) {
                // The first image that reports pk column names supplies the column layout for the whole table.
                if (pkNames == null) {
                    pkNames = recordImage.getPkColumnsName();
                    selectColumns = recordImage.getAllColumnsWithAlias();
                }
                if (CollectionUtil.isNotEmpty(recordImage.getInsertPks())) {
                    pks.addAll(recordImage.getInsertPks());
                }
                if (CollectionUtil.isNotEmpty(recordImage.getPossibleUpdateOrDeletePks())) {
                    pks.addAll(recordImage.getPossibleUpdateOrDeletePks());
                }
            }
            if (CollectionUtil.isNotEmpty(pks)) {
                TableRecord tableRecord = defaultSqlRunner.getTableRecord(tableEntry.getKey(), pkNames, pks, selectColumns);
                tableRecordMap.put(tableEntry.getKey(), tableRecord);
            }
        }
        for (RecordImage recordImage : this.recordImageMap.values()) {
            TableRecord tableRecord = tableRecordMap.get(recordImage.getTableName().toUpperCase(Locale.ROOT));
            if (tableRecord != null) {
                populateRows(recordImage, tableRecord, recordImage.getPossibleUpdateOrDeletePks());
                populateRows(recordImage, tableRecord, recordImage.getInsertPks());
            }
        }
    }

    /**
     * Validates that the batch does not operate on the same record more than once.
     *
     * <p>Keys are compared per upper-cased table name. INSERT images contribute their inserted keys, and
     * UPDATE/DELETE images contribute their possibly-affected keys. Missing (null or empty) key lists are
     * skipped.
     *
     * @throws SQLException if the same primary key appears more than once for a table
     */
    public void validNotContainSameRecord() throws SQLException {
        for (Map.Entry<String, List<RecordImage>> tableEntry : groupByUpperTableName().entrySet()) {
            Set<String> seenKeys = new HashSet<>();
            for (RecordImage recordImage : tableEntry.getValue()) {
                List<List<Object>> affectedPks = null;
                SqlType sqlType = recordImage.getSqlType();
                if (sqlType == SqlType.INSERT) {
                    affectedPks = recordImage.getInsertPks();
                } else if (sqlType == SqlType.UPDATE || sqlType == SqlType.DELETE) {
                    affectedPks = recordImage.getPossibleUpdateOrDeletePks();
                }
                // Same null/empty guard the rest of this class applies to these lists.
                if (CollectionUtil.isEmpty(affectedPks)) {
                    continue;
                }
                List<String> pkColumnsName = recordImage.getPkColumnsName();
                for (List<Object> affectedPk : affectedPks) {
                    String key = defaultSqlRunner.assembleKey(pkColumnsName, affectedPk);
                    if (!seenKeys.add(key)) {
                        throw new SQLException("表：" + tableEntry.getKey() + "，批处理对一条数据进行多次操作");
                    }
                }
            }
        }
    }

    /**
     * Formats every captured record image via the shared runner and returns them.
     *
     * @return a new list holding the record images in {@code recordImageMap} iteration order
     */
    public Collection<RecordImage> executeResult() {
        List<RecordImage> result = new ArrayList<>(recordImageMap.values());
        for (RecordImage recordImage : result) {
            defaultSqlRunner.formatRecordImage(recordImage);
        }
        return result;
    }

    /**
     * Copies the rows of {@code tableRecord} whose keys match {@code pkValues} into the after-record of
     * {@code recordImage}; keys with no matching row are skipped. The after-record's delimiter is always
     * set to the shared runner's delimiter, even when {@code pkValues} is null or empty.
     *
     * @param recordImage image whose after-record receives the rows
     * @param tableRecord rows re-read after execution, keyed by assembled primary key
     * @param pkValues    primary-key values to look up; may be null or empty
     */
    public void populateRows(RecordImage recordImage, TableRecord tableRecord, List<List<Object>> pkValues) {
        if (CollectionUtil.isNotEmpty(pkValues)) {
            Map<String, List<Object>> rows = tableRecord.getRows();
            for (List<Object> pkValue : pkValues) {
                String pkKey = defaultSqlRunner.assembleKey(recordImage.getPkColumnsName(), pkValue);
                List<Object> row = rows.get(pkKey);
                if (row != null) {
                    recordImage.getAfterRecord().addRow(pkKey, row);
                }
            }
        }
        recordImage.getAfterRecord().setDelimiter(defaultSqlRunner.getDelimiter());
    }

    /**
     * Groups the captured record images by upper-cased table name.
     */
    private Map<String, List<RecordImage>> groupByUpperTableName() {
        return this.recordImageMap.values().stream()
                .collect(Collectors.groupingBy(recordImage -> recordImage.getTableName().toUpperCase(Locale.ROOT),
                        Collectors.toList()));
    }

    /**
     * Captures the INSERT record image for one table.
     *
     * <p><b>Assumes that whenever the batch contains an INSERT, it includes the primary key.</b> The runner's
     * own pk values are used, falling back to {@code extractPkValues()} when those are empty.
     *
     * @param tableName         upper-cased table name
     * @param statementMetaData INSERT metadata for this table (non-empty when called from beforeExecute)
     * @return the record image that was stored in {@code recordImageMap}
     * @throws SQLException if extracting primary-key values fails
     */
    private RecordImage handleInsert(String tableName, List<StatementMetaData> statementMetaData) throws SQLException {
        List<List<Object>> insertPkValues = new ArrayList<>();
        List<String> allColumns = new ArrayList<>();
        List<String> primaryKeys = new ArrayList<>();
        TableMeta tableMeta = null;
        boolean first = true;
        for (StatementMetaData metaData : statementMetaData) {
            AbstractInsertSqlRunner sqlRunner = (AbstractInsertSqlRunner) metaData.getSqlRunner();
            if (first) {
                // Every entry in this group shares one table name, so the first runner's table metadata is used.
                tableMeta = sqlRunner.getTableMeta();
                allColumns = tableMeta.getColumnNamesByOrder();
                primaryKeys = tableMeta.getPrimaryKeyWithAlias("", true);
                first = false;
            }
            List<List<Object>> pkValues = sqlRunner.getPkValues();
            if (CollectionUtil.isEmpty(pkValues)) {
                pkValues = sqlRunner.extractPkValues();
            }
            // Guard the fallback as well: addAll(null) would throw a NullPointerException.
            if (CollectionUtil.isNotEmpty(pkValues)) {
                insertPkValues.addAll(pkValues);
            }
        }
        RecordImage recordImage = RecordImage.builder().allColumnsWithAlias(allColumns).pkColumnsName(primaryKeys)
                .insertPks(insertPkValues).sqlType(SqlType.INSERT).tableName(tableName).tableMeta(tableMeta)
                .build();
        recordImageMap.put(this.getKey(tableName, SqlType.INSERT), recordImage);
        return recordImage;
    }

    /**
     * Captures the UPDATE or DELETE before-image for one table.
     *
     * <p>The per-statement record queries are combined with {@code union}, and their placeholder indexes are
     * merged. The combined query runs through {@code executeParametersHolderSqlWithLockTable}; presumably that
     * also takes a lock, per its name — confirm. The returned rows become the before-record, and their
     * primary-key values become the possibly-affected keys.
     *
     * @param tableName         upper-cased table name
     * @param sqlType           {@link SqlType#UPDATE} or {@link SqlType#DELETE}
     * @param statementMetaData metadata for this table and type (non-empty when called from beforeExecute)
     * @return the record image that was stored in {@code recordImageMap}
     * @throws SQLException if the query fails
     */
    private RecordImage handleDeleteOrUpdate(String tableName, SqlType sqlType, List<StatementMetaData> statementMetaData) throws SQLException {
        AbstractSqlRunner sqlRunner = (AbstractSqlRunner) statementMetaData.get(0).getSqlRunner();
        // UNION is used instead of OR-joining the WHERE conditions because handling table aliases that way is cumbersome.
        StringJoiner unionSql = new StringJoiner(" union ");
        List<Integer> placeHolder = new ArrayList<>();
        for (StatementMetaData metaData : statementMetaData) {
            unionSql.add(metaData.getTableRecordsSql());
            List<Integer> placeHolderIndex = metaData.getPlaceHolderIndex();
            if (CollectionUtil.isNotEmpty(placeHolderIndex)) {
                placeHolder.addAll(placeHolderIndex);
            }
        }
        AbstractSqlRunner.PlaceHolderSqlContext currentSqlContext = sqlRunner.getCurrentSqlContext(unionSql.toString(), placeHolder, true);
        List<List<Object>> rows = defaultSqlRunner.executeParametersHolderSqlWithLockTable(currentSqlContext);
        List<Integer> indexList = IndexUtils.getIndexList(currentSqlContext.getPkColumnsNameWithNoAlias(), currentSqlContext.getSelectColumnsWithNoAlias());
        List<List<Object>> pkValues = defaultSqlRunner.getIndexValue4List(rows, indexList, false);
        TableRecord tableRecord = defaultSqlRunner.buildTableRecord(rows, currentSqlContext.getPkColumnsNameWithNoAlias(), currentSqlContext.getSelectColumnsWithNoAlias());
        RecordImage recordImage = RecordImage.builder().tableName(tableName).pkColumnsName(currentSqlContext.getPkColumnsNameWithNoAlias())
                .allColumnsWithAlias(currentSqlContext.getSelectColumnsWithNoAlias()).sqlType(sqlType).beforeRecord(tableRecord).possibleUpdateOrDeletePks(pkValues)
                .tableMeta(sqlRunner.getTableMeta())
                .build();
        recordImageMap.put(this.getKey(tableName, sqlType), recordImage);
        return recordImage;
    }

    /**
     * Builds the {@code recordImageMap} key: the upper-cased table name, the delimiter, then the SQL type.
     */
    private String getKey(String k1, SqlType sqlType) {
        return k1.toUpperCase(Locale.ROOT) + KEY_DELIMITER + sqlType;
    }
}
